import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc('xtick', labelsize=20) 
matplotlib.rc('ytick', labelsize=20) 
import os

import seaborn as sns
sns.set()

def get_mean_std(ress):
    """Return ``(mean, std)`` of ``ress``, both reduced along axis 0.

    ``ress`` is anything NumPy can treat as an array. The standard
    deviation uses NumPy's default (population, ``ddof=0``) definition.
    """
    center = np.mean(ress, axis=0)
    spread = np.std(ress, axis=0)
    return center, spread


if __name__ == '__main__':
    # Datasets to plot. Earlier selections are kept for reference only; the
    # last assignment is the one that was actually in effect.
    # dataset = ['covertype', 'MagicTelescope','shuttle','mushroom','fashion']
    # dataset = ['covertype']
    dataset = ['cos', 'quad']

    # One entry per curve: (result-file template, colour, legend label).
    # The template's "{}" is filled with the dataset name. Disabled curves are
    # left commented out so they can be re-enabled by uncommenting one line.
    curves = [
        ("./results/ConservativeSquareCB_{}-new.npy", 'red',
         "Conservative-SquareCB (new)"),
        # Distinct label: the old and new runs previously shared one legend
        # entry, which made the comparison impossible to read.
        ("./results/old_results/ConservativeSquareCB_{}.npy", 'green',
         "Conservative-SquareCB (old)"),
        # ("./results/ConservativeFastCB_results_{}-new.npy", 'blue',
        #  "Conservative-FastCB"),
        # ("./results/ConservativeLinUCB_{}.npy", 'green',
        #  "Conservative-LinUCB"),
    ]

    path = os.getcwd()
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(os.path.join(path, 'figures'), exist_ok=True)

    for d in dataset:
        fig = plt.figure(figsize=(10, 6))

        for template, colour, label in curves:
            # NOTE(review): presumably each file holds one row per run and one
            # column per round -- confirm against the code that writes them.
            runs = np.load(template.format(d))
            mean, std = get_mean_std(runs)
            # Derive the x axis from the data instead of hard-coding 5000 so
            # result files with a different horizon still plot.
            x = range(len(mean))
            plt.plot(x, mean, color=colour, linewidth=2.0, label=label)
            plt.fill_between(x, mean - std, mean + std,
                             facecolor=colour, alpha=0.2)

        plt.xlabel('Rounds', fontsize=20)
        plt.ylabel('Regret', fontsize=20)
        plt.legend(prop={"size": 20})
        plt.title(d, fontsize=20)

        print(path)
        # Save BEFORE show(): with interactive backends the figure is torn
        # down when the window is closed, so saving afterwards writes a
        # blank image.
        plt.savefig('{}/figures/regret_{}-trial.jpg'.format(path, d),
                    dpi=500, bbox_inches='tight')
        plt.show()
        # Release the figure so they do not pile up across datasets.
        plt.close(fig)
